import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
from matplotlib.patches import Patch
import matplotlib.lines as mlines
sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import weightedBinscatter as wb
import estimators as et
#sys.path.append('packages')
#import binscatter
import matplotlib.font_manager as fm
import matplotlib
import statsmodels.formula.api as smf
from statsmodels.iolib.smpickle import load_pickle
import joblib
from trajectoryPlotters import *

# ---- plot defaults ----
# Apply the project style sheet, then register the font file below under the
# name 'latex', place it at index 0 of matplotlib's font list, and make that
# name the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# ---- shared colours ----
# Colour dictionaries come from configureColors; each group entry carries a
# 'color' and an 'alpha' value, unpacked here into short module-level names.
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()
cb, ab = colorsBNB['Black']['color'], colorsBNB['Black']['alpha']
cnb, anb = colorsBNB['non-Black']['color'], colorsBNB['non-Black']['alpha']
ce, ae = colorsENE['eitc']['color'], colorsENE['eitc']['alpha']
cne, ane = colorsENE['non']['color'], colorsENE['non']['alpha']

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
	"""Regress `outcome` on `pbvarb` by OLS with HC1 robust standard errors.

	Args:
		dataset: DataFrame containing the columns named by `pbvarb` and `outcome`.
		pbvarb: name of the regressor column (single term on the formula's right-hand side).
		outcome: name of the dependent-variable column.

	Returns:
		Tuple ``(coef, se, black_aud_rate, non_black_aud_rate)`` where
		``coef`` / ``se`` are the estimate and HC1 standard error on `pbvarb`,
		``non_black_aud_rate`` is the first fitted parameter (the intercept) and
		``black_aud_rate`` is the sum of the first two fitted parameters
		(intercept + slope, i.e. the fitted value at `pbvarb` = 1).
	"""
	# smf.ols is the formula constructor for OLS; the file imports
	# statsmodels.formula.api as smf but never imports a name `sm`.
	model = smf.ols(outcome+' ~ '+pbvarb, data = dataset).fit(cov_type = 'HC1')
	coef = model.params[pbvarb]
	se = model.bse[pbvarb]
	# params is a label-indexed Series; use .iloc for positional access because
	# integer keys passed to [] on such a Series are deprecated in pandas.
	intercept = model.params.iloc[0]
	black_aud_rate = intercept + model.params.iloc[1]
	non_black_aud_rate = intercept
	return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
	"""Compare probability-weighted means of `outcome` between the two groups.

	Each row contributes to the "Black" mean with weight `pbvarb` and to the
	"non-Black" mean with weight ``1 - pbvarb``.

	Args:
		dataset: DataFrame containing the columns named by `pbvarb` and `outcome`.
		pbvarb: name of the probability column used as the weight.
		outcome: name of the outcome column being averaged.

	Returns:
		Tuple ``(est, black_aud_rate, non_black_aud_rate)`` with
		``est = black_aud_rate - non_black_aud_rate``.
	"""
	weight_black = dataset[pbvarb]
	weight_non_black = 1-weight_black
	y = dataset[outcome]
	# Weighted mean = sum(weight * outcome) / sum(weight), once per group.
	black_aud_rate = (weight_black*y).sum()/weight_black.sum()
	non_black_aud_rate = (weight_non_black*y).sum()/weight_non_black.sum()
	return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
	"""Return the supplied regression SE together with its rescaled counterpart.

	Args:
		dataset: DataFrame containing the column named by `pbvarb`.
		pbvarb: name of the probability column passed on to getSEMultiplier.
		outcome: accepted for signature compatibility; not used by this body.
		seReg: standard error to rescale; it is supplied by the caller and is
			not computed here (the default 0.0 yields a rescaled SE of 0).

	Returns:
		Tuple ``(seReg, seChen)`` where ``seChen = seReg * getSEMultiplier(dataset, pbvarb)``.
	"""
	seChen = seReg*getSEMultiplier(dataset,pbvarb)
	return seReg, seChen

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
	"""Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for the column `pbvarb`.

	Args:
		dataset: DataFrame containing the column named by `pbvarb`.
		pbvarb: name of the probability column.

	Returns:
		The square root of the column's variance (pandas `Series.var`) divided
		by ``m * (1 - m)``, where ``m`` is the column mean.
	"""
	probs = dataset[pbvarb]
	avg = probs.mean()
	denominator = avg*(1-avg)
	return np.sqrt(probs.var()/denominator)

### OUTPUT

# Prefix prepended to every figure filename saved by this script.
out = '/REDACTED/'

### OP AUDIT
# Load the individual-level file and keep only rows that have a value in
# predicted_prob_black; `data` is an independent copy of the filtered frame.
dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv")
dataBISG = dataBISG[dataBISG.predicted_prob_black.notna()]
data = dataBISG.copy()

# North Carolina data and weights
ncdata = pd.read_csv("REDACTED")
ncweights = pd.read_csv("REDACTED")

# merge together and subset to those that have ground truth race
ncdata = ncdata[ncdata.black_ind.notna()]
ncdata = ncdata[['taxpayer_id', 'black_ind']]

# pd.merge defaults to an inner join, so only taxpayer_ids present on both
# sides of each merge survive.
nc_BISG = data[data['state']=="NC"]
nc_merged = pd.merge(ncdata, nc_BISG, on = 'taxpayer_id')

ncweights = ncweights[['taxpayer_id', 'uswgt']]
nc_final = pd.merge(nc_merged, ncweights, on = 'taxpayer_id')

# separate into ground truth black and ground truth non-black datasets
# .copy() makes each subset an independent frame, so adding columns to it
# later does not raise pandas' SettingWithCopyWarning.
data_black = nc_final[nc_final.black_ind == 1].copy()
data_nonblack = nc_final[nc_final.black_ind == 0].copy()
######################
##### Figure 1 (left): Distribution and Calibration of Race Imputation
######################

def _plotProbBlackHist(frame, filename):
	"""Bucket predicted_prob_black into 0.05-wide bins and save a bar chart of bin shares.

	Args:
		frame: DataFrame with `predicted_prob_black` and `taxpayer_id` columns.
		filename: file name appended to the module-level `out` prefix.

	Returns:
		Tuple ``(bucketed, counts, fig, ax)``: a copy of `frame` with a float
		`pBlack_bucket` column (each bin's lower edge), the per-bucket row
		counts, and the figure / axes that were drawn and saved.
	"""
	# 22 edges (0, 0.05, ..., 1.05) give 21 left-closed bins; the extra top bin
	# [1.0, 1.05) catches a probability of exactly 1.0 because right=False.
	edges = [i/20 for i in range(22)]
	lower_edges = [i/20 for i in range(21)]
	buckets = pd.cut(frame.predicted_prob_black,bins=edges,labels=lower_edges, right=False,include_lowest=True).astype(float)
	# assign() returns a new frame, so the caller's object is never mutated and
	# no SettingWithCopyWarning is raised when `frame` is a filtered selection.
	bucketed = frame.assign(pBlack_bucket=buckets)
	counts = bucketed.groupby('pBlack_bucket').taxpayer_id.count().reset_index()
	fig,ax = plt.subplots(figsize=(10,8))
	# NOTE: bars use matplotlib's default centre alignment, so each bar is
	# centred on its bin's lower edge (same as the original plt.bar call).
	ax.bar(counts.pBlack_bucket,counts.taxpayer_id/counts.taxpayer_id.sum(),width=0.05)
	ax.set_xlabel('BIFSG-Predicted Probability Black')
	ax.set_ylabel('Population Share')
	fig.savefig(out+filename)
	plt.close('all')
	return bucketed, counts, fig, ax

# for true Black, plot distribution of BIFSG scores
data_black, pct_black_counts, fig, ax = _plotProbBlackHist(data_black, 'probBlack_hist_fw_true_black.png')

# for true non-Black, plot distribution of BIFSG scores
data_nonblack, pct_black_counts, fig, ax = _plotProbBlackHist(data_nonblack, 'probBlack_hist_fw_true_nonblack.png')